import os
import sys

# Switch the working directory to sys.path[0] (normally the directory of the
# invoked script) so the relative path appended below resolves from there.
os.chdir(sys.path[0])
# Add the directory two levels up to the import search path; presumably this
# is what lets the `utils.*` imports further down resolve -- confirm layout.
sys.path.append("../../")
# NOTE(review): the return value is discarded, so this call has no effect.
os.getcwd()

import logging

import numpy as np
import torch
import torch.nn as nn
from torch import FloatTensor
from torch.utils.data import DataLoader
from tqdm import tqdm


from utils.aggregate_block.train_settings_generate import (
    argparser_criterion,
    argparser_opt_scheduler,
)
from utils.trainer_cls import BackdoorModelTrainer
from utils.aggregate_block.model_trainer_generate import generate_cls_model


class SelectionTrainer(BackdoorModelTrainer):
    def __init__(self, args, device):
        self.args = args
        self.device = device

    def set_optimizer_scheduler(self, optimizer=None, scheduler=None):
        if optimizer is not None:
            self.optimizer = optimizer
        elif scheduler is not None:
            self.scheduler = scheduler
        else:
            self.optimizer, self.scheduler = argparser_opt_scheduler(self.model, self.args)

    def set_optimizer(self, optimizer=None):
        if optimizer is not None:
            self.optimizer = optimizer

    def set_scheduler(self, scheduler=None):
        if scheduler is not None:
            self.scheduler = scheduler

    def set_model(self, model=None):
        if model is not None:
            self.model = model
        else:
            self.model = generate_cls_model(
                model_name=self.args.model, num_classes=self.args.num_classes, image_size=self.args.img_size[0]
            )
        self.model.to(self.device)

    # Train on the custom dataset: concatenate the clean and backdoor training data, run a single forward pass, then blend the per-sample results according to poison_idx to obtain the loss.
    def train_one_epoch_concat(self, mix_bd_train_dataloader, mask):
        self.model.train()
        self.criterion = nn.CrossEntropyLoss(reduction="none")
        all_bd_loss = torch.zeros(len(mix_bd_train_dataloader.dataset))
        all_cl_loss = torch.zeros(len(mix_bd_train_dataloader.dataset))
        all_bd_preds = torch.zeros(len(mix_bd_train_dataloader.dataset))
        all_cl_preds = torch.zeros(len(mix_bd_train_dataloader.dataset))
        all_bd_labels = torch.zeros(len(mix_bd_train_dataloader.dataset))
        all_cl_labels = torch.zeros(len(mix_bd_train_dataloader.dataset))

        loss_list = []

        for (
            batch_idx,
            (
                bd_imgs,
                bd_labels,
                cl_imgs,
                cl_labels,
                original_idxs,
            ),
        ) in enumerate(mix_bd_train_dataloader):
            (
                bd_imgs,
                bd_labels,
                cl_imgs,
                cl_labels,
            ) = (
                bd_imgs.to(self.device),
                bd_labels.to(self.device),
                cl_imgs.to(self.device),
                cl_labels.to(self.device),
            )

            poison_idxs = FloatTensor(mask[original_idxs]).to(self.device)

            batch_imgs = torch.cat((bd_imgs, cl_imgs), 0)
            batch_preds = self.model(batch_imgs)

            bd_preds = batch_preds[: len(bd_imgs)]
            cl_preds = batch_preds[len(bd_imgs) :]

            batch_bd_loss = self.criterion(bd_preds, bd_labels)
            batch_cl_loss = self.criterion(cl_preds, cl_labels)

            bd_training_loss = (
                torch.dot(poison_idxs, batch_bd_loss) + torch.dot(1 - poison_idxs, batch_cl_loss)
            ) / len(bd_imgs)

            loss = bd_training_loss
            loss_list.append(loss.detach().clone().item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            all_bd_loss[original_idxs] = batch_bd_loss.cpu().detach()
            all_cl_loss[original_idxs] = batch_cl_loss.cpu().detach()
            all_bd_preds[original_idxs] = torch.argmax(bd_preds, dim=1).float().detach().cpu()
            all_cl_preds[original_idxs] = torch.argmax(cl_preds, dim=1).float().detach().cpu()
            all_bd_labels[original_idxs] = bd_labels.cpu().detach().float()
            all_cl_labels[original_idxs] = cl_labels.cpu().detach().float()

        if self.scheduler is not None:
            self.scheduler.step()

        epoch_avg_loss = np.mean(loss_list)

        return (
            epoch_avg_loss,
            all_bd_loss,
            all_cl_loss,
            all_bd_preds,
            all_cl_preds,
            all_bd_labels,
            all_cl_labels,
        )

    # Train on the custom dataset: first mix the clean and backdoor training data according to poison_idxs, then train on the mixed batch.
    def train_one_epoch_mix(self, mix_bd_train_dataloader, mask):
        self.model.train()
        self.criterion = nn.CrossEntropyLoss()
        loss_list = []
        num = mask.shape[0]
        epoch_label = torch.zeros(num).to(self.device)
        epoch_predict = torch.zeros(num).to(self.device)
        epoch_orgin_label = torch.zeros(num).to(self.device)
        epoch_origin_idx = torch.zeros(num)
        for batch_idx, item in enumerate(mix_bd_train_dataloader):
            bd_imgs, bd_labels, cl_imgs, cl_labels, original_idxs = item
            bd_imgs = bd_imgs.to(self.device)
            bd_labels = bd_labels.to(self.device)
            cl_imgs = cl_imgs.to(self.device)
            cl_labels = cl_labels.to(self.device)
            poison_idxs = torch.tensor(mask[original_idxs]).to(self.device)

            bd_idx = torch.where(poison_idxs == 1)[0]
            cl_idx = torch.where(poison_idxs == 0)[0]

            batch_imgs = torch.zeros_like(bd_imgs).to(self.device)
            batch_imgs[bd_idx] = bd_imgs[bd_idx]
            batch_imgs[cl_idx] = cl_imgs[cl_idx]

            batch_label = torch.zeros_like(bd_labels).to(self.device)
            batch_label[bd_idx] = bd_labels[bd_idx]
            batch_label[cl_idx] = cl_labels[cl_idx]

            batch_probs = self.model(batch_imgs)
            batch_loss = self.criterion(batch_probs, batch_label)

            batch_predict = torch.max(batch_probs, -1)[1].detach().clone()

            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

            loss_list.append(batch_loss)
            epoch_label[original_idxs] = batch_label.float()
            epoch_orgin_label[original_idxs] = cl_labels.float()
            epoch_predict[original_idxs] = batch_predict.float()
            epoch_origin_idx[original_idxs] = original_idxs.float()

        if self.scheduler is not None:
            self.scheduler.step()

        one_epoch_loss = sum(loss_list) / len(loss_list)

        bd_idx = np.where(mask == 1)[0]
        cl_idx = np.where(mask == 0)[0]

        train_clean_acc = self.all_acc(
            epoch_predict[cl_idx],
            epoch_label[cl_idx],
        )
        train_asr = self.all_acc(
            epoch_predict[bd_idx],
            epoch_label[bd_idx],
        )
        train_ra = self.all_acc(
            epoch_predict[bd_idx],
            epoch_orgin_label[bd_idx],
        )
        train_mix_acc = self.all_acc(
            epoch_predict,
            epoch_label,
        )

        return (
            one_epoch_loss.cpu().detach().item(),
            train_mix_acc,
            train_asr,
            train_clean_acc,
            train_ra,
            epoch_label,
            epoch_orgin_label,
            epoch_predict,
            epoch_origin_idx,
        )

    def train_one_epoch_normal(self, normal_bd_train_dataloader):
        self.model.train()
        self.criterion = nn.CrossEntropyLoss()
        batch_loss_list = []
        batch_predict_list = []
        batch_label_list = []
        batch_original_index_list = []
        batch_poison_indicator_list = []
        batch_original_labels_list = []

        for batch_idx, item in enumerate(normal_bd_train_dataloader):
            (
                imgs,
                labels,
                original_index,
                poison_indicator,
                original_labels,
            ) = item
            imgs = imgs.to(self.device)
            labels = labels.to(self.device)
            batch_probs = self.model(imgs)
            batch_loss = self.criterion(batch_probs, labels)
            batch_predict = torch.max(batch_probs, -1)[1].detach().clone()

            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

            batch_loss_list.append(batch_loss)
            batch_predict_list.append(batch_predict.detach().clone().cpu())
            batch_label_list.append(labels.detach().clone().cpu())
            batch_original_index_list.append(original_index.detach().clone().cpu())
            batch_poison_indicator_list.append(poison_indicator.detach().clone().cpu())
            batch_original_labels_list.append(original_labels.detach().clone().cpu())

        if self.scheduler is not None:
            self.scheduler.step()

        one_epoch_loss = sum(batch_loss_list) / len(batch_loss_list)

        epoch_predicts = torch.cat(batch_predict_list)
        epoch_labels = torch.cat(batch_label_list)
        epoch_orginal_idxs = torch.cat(batch_original_index_list)
        epoch_poison_idxs = torch.cat(batch_poison_indicator_list)
        epoch_orignal_labels = torch.cat(batch_original_labels_list)

        epoch_bd_idx = torch.where(epoch_poison_idxs == 1)[0]
        epoch_clean_idx = torch.where(epoch_poison_idxs == 0)[0]

        train_clean_acc = self.all_acc(
            epoch_predicts[epoch_clean_idx],
            epoch_labels[epoch_clean_idx],
        )
        train_asr = self.all_acc(
            epoch_predicts[epoch_bd_idx],
            epoch_labels[epoch_bd_idx],
        )
        train_ra = self.all_acc(
            epoch_predicts[epoch_bd_idx],
            epoch_orignal_labels[epoch_bd_idx],
        )
        train_mix_acc = self.all_acc(
            epoch_predicts,
            epoch_labels,
        )

        return (
            one_epoch_loss.cpu().detach().item(),
            train_mix_acc,
            train_asr,
            train_clean_acc,
            train_ra,
            epoch_predicts,
            epoch_labels,
            epoch_orginal_idxs,
            epoch_poison_idxs,
            epoch_orignal_labels,
        )

    @torch.no_grad()
    def test_normal(self, normal_test_dataloader):
        self.model.eval()
        self.criterion = nn.CrossEntropyLoss(reduction="none")
        batch_loss_list = []
        batch_meanloss_list = []
        batch_predict_list = []
        batch_label_list = []
        batch_prob_list = []

        for batch_idx, item in enumerate(normal_test_dataloader):
            imgs, labels = item[0], item[1]
            imgs = imgs.to(self.device)
            labels = labels.to(self.device)
            batch_probs = self.model(imgs)
            batch_loss = self.criterion(batch_probs, labels)
            mean_loss = torch.mean(batch_loss)
            batch_meanloss_list.append(mean_loss)
            batch_predict = torch.max(batch_probs, -1)[1].detach().clone()

            batch_loss_list.append(batch_loss)
            batch_predict_list.append(batch_predict.detach().clone().cpu())
            batch_label_list.append(labels.detach().clone().cpu())
            batch_prob_list.append(batch_probs.detach().clone().cpu())

        all_loss = sum(batch_meanloss_list) / len(batch_meanloss_list)
        epoch_predicts = torch.cat(batch_predict_list)
        epoch_labels = torch.cat(batch_label_list)
        epoch_probs = torch.cat(batch_prob_list)
        epoch_loss = torch.cat(batch_loss_list)

        acc = self.all_acc(epoch_predicts, epoch_labels)

        return (
            all_loss.cpu().detach().item(),
            epoch_loss,
            epoch_predicts,
            epoch_labels,
            epoch_probs,
            acc,
        )

    @torch.no_grad()
    def test_concat(self, mix_bd_dataloader, mask):
        self.model.eval()
        self.criterion = nn.CrossEntropyLoss(reduction="none")
        all_bd_loss = torch.zeros(len(mix_bd_dataloader.dataset))
        all_cl_loss = torch.zeros(len(mix_bd_dataloader.dataset))
        all_bd_preds = torch.zeros(len(mix_bd_dataloader.dataset))
        all_cl_preds = torch.zeros(len(mix_bd_dataloader.dataset))
        all_bd_labels = torch.zeros(len(mix_bd_dataloader.dataset))
        all_cl_labels = torch.zeros(len(mix_bd_dataloader.dataset))

        loss_list = []

        for (
            batch_idx,
            (
                bd_imgs,
                bd_labels,
                cl_imgs,
                cl_labels,
                original_idxs,
            ),
        ) in enumerate(mix_bd_dataloader):
            (
                bd_imgs,
                bd_labels,
                cl_imgs,
                cl_labels,
            ) = (
                bd_imgs.to(self.device),
                bd_labels.to(self.device),
                cl_imgs.to(self.device),
                cl_labels.to(self.device),
            )

            poison_idxs = FloatTensor(mask[original_idxs]).to(self.device)

            batch_imgs = torch.cat((bd_imgs, cl_imgs), 0)
            batch_preds = self.model(batch_imgs)

            bd_preds = batch_preds[: len(bd_imgs)]
            cl_preds = batch_preds[len(bd_imgs) :]

            batch_bd_loss = self.criterion(bd_preds, bd_labels)
            batch_cl_loss = self.criterion(cl_preds, cl_labels)

            bd_training_loss = (
                torch.dot(poison_idxs, batch_bd_loss) + torch.dot(1 - poison_idxs, batch_cl_loss)
            ) / len(bd_imgs)

            loss = bd_training_loss
            loss_list.append(loss.detach().clone().item())

            all_bd_loss[original_idxs] = batch_bd_loss.cpu().detach()
            all_cl_loss[original_idxs] = batch_cl_loss.cpu().detach()
            all_bd_preds[original_idxs] = torch.argmax(bd_preds, dim=1).float().detach().cpu()
            all_cl_preds[original_idxs] = torch.argmax(cl_preds, dim=1).float().detach().cpu()
            all_bd_labels[original_idxs] = bd_labels.cpu().detach().float()
            all_cl_labels[original_idxs] = cl_labels.cpu().detach().float()

        epoch_avg_loss = np.mean(loss_list)

        return (
            epoch_avg_loss,
            all_bd_loss,
            all_cl_loss,
            all_bd_preds,
            all_cl_preds,
            all_bd_labels,
            all_cl_labels,
        )

    def all_acc(
        self,
        preds: torch.Tensor,
        labels: torch.Tensor,
    ):
        if len(preds) == 0 or len(labels) == 0:
            logging.warning("zero len array in func all_acc(), return None!")
            return None
        return preds.eq(labels).sum().item() / len(preds)
